// This file is licensed under the Elastic License 2.0. Copyright 2021-present, StarRocks Limited.

package com.starrocks.sql.optimizer;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.starrocks.common.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * The space of plan alternatives generated by the optimizer is encoded in
 * a compact in-memory data structure called the Memo.
 * <p>
 * The Memo provides memoization, duplicate detection, and property + cost management.
 * <p>
 * This is a key component of our optimizer. We use
 * dynamic programming to search for the optimal query plan. While
 * searching, the results of sub-problems need to be stored, so
 * we keep every GroupExpression that has been searched in this structure.
 * <p>
 * The Memo stores all GroupExpressions in a hash map,
 * which makes it efficient to check whether a GroupExpression already exists.
 * <p>
 * All groups should be created through this class.
 */
public class Memo {
    private static final Logger LOG = LogManager.getLogger(Memo.class);

    // Id handed to the next group created by newGroup()
    private int nextGroupId = 0;

    // Groups are appended in the order their ids are allocated. Merged, empty and unreachable
    // groups are removed from this list, so a list index is not guaranteed to equal the group id.
    private final List<Group> groups;

    private Group rootGroup;
    /**
     * Registry of every GroupExpression stored in the Memo. Each entry maps an expression to
     * itself, so that {@link #insertGroupExpression(GroupExpression, Group)} can take a temporary
     * GroupExpression, which has no group info yet, look up the equal expression that is
     * already stored, and reach the owning group through it.
     */
    private final Map<GroupExpression, GroupExpression> groupExpressions;

    /**
     * @return the live list of groups held by this Memo (not a copy)
     */
    public List<Group> getGroups() {
        return groups;
    }

    /**
     * @return the live expression registry held by this Memo (not a copy)
     */
    public Map<GroupExpression, GroupExpression> getGroupExpressions() {
        return groupExpressions;
    }

    public Memo() {
        groups = Lists.newLinkedList();
        groupExpressions = Maps.newHashMap();
    }

    /**
     * @return the group of the root expression; null until {@link #init(OptExpression)} has run
     */
    public Group getRootGroup() {
        return rootGroup;
    }

    /**
     * Copy an expression into the search space. This adds a GroupExpression for
     * the Expression; if the Expression has children, they are copied in
     * recursively so that every single Expression gets its GroupExpression and Group.
     * For example, Join(Scan(A), Scan(B)) will create 3 Groups and GroupExpressions for Join,
     * Scan(A) and Scan(B).
     * We return the GroupExpression rather than the Group because the Group can be obtained
     * from the GroupExpression.
     *
     * @param originExpression root of the expression tree to copy in
     * @return the GroupExpression stored for the root; its group becomes the root group
     * @throws IllegalStateException if the Memo already contains groups or group expressions
     */
    public GroupExpression init(OptExpression originExpression) {
        Preconditions.checkState(groups.isEmpty());
        Preconditions.checkState(groupExpressions.isEmpty());
        GroupExpression rootGroupExpression = copyIn(null, originExpression).second;
        rootGroup = rootGroupExpression.getGroup();
        return rootGroupExpression;
    }

    /**
     * Register a GroupExpression in the Memo, de-duplicating against the stored expressions.
     *
     * @param groupExpression the expression to insert
     * @param targetGroup     the group that should own the expression, or null to create a new group
     * @return (false, existing expression) when an equal expression is already stored; in that
     * case, if targetGroup is non-null and differs from the existing expression's group,
     * the existing group is merged into targetGroup. Otherwise (true, groupExpression).
     */
    public Pair<Boolean, GroupExpression> insertGroupExpression(GroupExpression groupExpression, Group targetGroup) {
        // Single lookup: stored values are never null, so a null result means "not present"
        GroupExpression existedGroupExpression = groupExpressions.get(groupExpression);
        if (existedGroupExpression != null) {
            Group existedGroup = existedGroupExpression.getGroup();

            if (needMerge(targetGroup, existedGroup)) {
                mergeGroup(existedGroup, targetGroup);
            }

            return new Pair<>(false, existedGroupExpression);
        }

        if (targetGroup == null) {
            targetGroup = newGroup();
            groups.add(targetGroup);
        }

        groupExpressions.put(groupExpression, groupExpression);

        targetGroup.addExpression(groupExpression);

        return new Pair<>(true, groupExpression);
    }

    /**
     * Insert an enforce expression into the target group.
     * Only the group reference of the expression is set here: the expression is not put into
     * the expression registry, and the target group is not asked to add it.
     */
    public void insertEnforceExpression(GroupExpression groupExpression, Group targetGroup) {
        groupExpression.setGroup(targetGroup);
    }

    // Create a group with the next free id. The caller is responsible for adding it to `groups`.
    private Group newGroup() {
        return new Group(nextGroupId++);
    }

    /**
     * Copy an expression tree into the Memo. Inputs that already carry a GroupExpression reuse
     * its group; all other inputs are copied in recursively into new groups.
     *
     * @param targetGroup the group that should own the top expression, or null to create a new group
     * @param expression  the expression tree to copy in
     * @return the result of {@link #insertGroupExpression(GroupExpression, Group)} for the top expression
     * @throws IllegalStateException if an input has no group, or an input's group is targetGroup
     */
    public Pair<Boolean, GroupExpression> copyIn(Group targetGroup, OptExpression expression) {
        List<Group> inputs = Lists.newArrayList();
        for (OptExpression input : expression.getInputs()) {
            Group group;
            if (input.getGroupExpression() != null) {
                group = input.getGroupExpression().getGroup();
            } else {
                group = copyIn(null, input).second.getGroup();
            }
            Preconditions.checkState(group != null);
            // An expression must not take its own group as an input
            Preconditions.checkState(group != targetGroup);
            inputs.add(group);
        }

        GroupExpression groupExpression = new GroupExpression(expression.getOp(), inputs);
        Pair<Boolean, GroupExpression> result = insertGroupExpression(groupExpression, targetGroup);
        if (result.first && targetGroup == null) {
            // For a new group, we need to derive the property from the expression
            // and set it on the new group
            Preconditions.checkState(result.second.getOp().isLogical());
            result.second.deriveLogicalPropertyItself();

            // For multi join reorder,
            // we have derived statistics in ReorderJoinRule
            result.second.getGroup().setStatistics(expression.getStatistics());
        }
        return result;
    }

    // A merge is needed only when a target group was given and it is not the existing group.
    private boolean needMerge(Group targetGroup, Group existedGroup) {
        return targetGroup != null && targetGroup != existedGroup;
    }

    // Merge srcGroup into dstGroup, then remove the groups reported by getAllEmptyGroups().
    private void mergeGroup(Group srcGroup, Group dstGroup) {
        mergeGroupImpl(srcGroup, dstGroup);
        // When some rule merges two groups into one group, or
        // the GroupExpressions of one group are all removed,
        // the group is empty and we should remove it.
        Set<Group> emptyGroups = getAllEmptyGroups();
        for (Group emptyGroup : emptyGroups) {
            removeOneGroup(emptyGroup);
        }
    }

    // Merge srcGroup into dstGroup; srcGroup is removed from the group list.
    private void mergeGroupImpl(Group srcGroup, Group dstGroup) {
        groups.remove(srcGroup);
        // Reset root group, a rewrite rule may eliminate the root group
        if (srcGroup == rootGroup) {
            rootGroup = dstGroup;
        }

        // If we change the child group of a GroupExpression, the hash value of the GroupExpression
        // will change, so we must reinsert the GroupExpression into the groupExpressions map
        List<GroupExpression> needModifyExpressions = Lists.newArrayList();
        for (Iterator<Map.Entry<GroupExpression, GroupExpression>> iterator = groupExpressions.entrySet().iterator();
                iterator.hasNext(); ) {
            GroupExpression groupExpr = iterator.next().getKey();

            // 1. find GroupExpressions which refer to the src group, and remove them from the map
            for (Group group : groupExpr.getInputs()) {
                if (group == srcGroup) {
                    // one matching input is enough to schedule the expression for rewriting
                    iterator.remove();
                    needModifyExpressions.add(groupExpr);
                    break;
                }
            }

            // 2. find GroupExpressions which belong to the src group
            if (groupExpr.getGroup() == srcGroup) {
                needModifyExpressions.add(groupExpr);
            }
        }

        // modify the group of each group expression and mark those that need to be reinserted
        List<GroupExpression> needReinsertedExpressions = Lists.newArrayList();
        for (GroupExpression modifyExpression : needModifyExpressions) {
            if (modifyExpression.getGroup() == srcGroup) {
                modifyExpression.setGroup(dstGroup);
            }

            List<Group> modifyInputs = modifyExpression.getInputs();
            boolean inputRewritten = false;
            for (int i = 0; i < modifyInputs.size(); i++) {
                if (modifyInputs.get(i) == srcGroup) {
                    if (!inputRewritten) {
                        // remove self from its group before the first input changes, and reinsert later
                        modifyExpression.getGroup().removeGroupExpression(modifyExpression);
                        inputRewritten = true;
                    }
                    modifyInputs.set(i, dstGroup);
                }
            }
            // Queue the expression once even if srcGroup filled several input slots; a second
            // entry would find the expression already reinserted and flag it as a redundant
            // duplicate of itself.
            if (inputRewritten) {
                needReinsertedExpressions.add(modifyExpression);
            }
        }

        Map<Group, Group> needMergeGroup = Maps.newHashMap();
        for (GroupExpression reinsertExpression : needReinsertedExpressions) {
            // an equal expression may already be in groupExpressions because the input was modified
            if (!groupExpressions.containsKey(reinsertExpression)) {
                groupExpressions.put(reinsertExpression, reinsertExpression);
                reinsertExpression.getGroup().addExpression(reinsertExpression);
            } else {
                // The group expression is already in the Memo's groupExpressions, which indicates that
                // this is a redundant group expression and should be removed.
                // The redundant group expression may already be in the TaskScheduler stack, so it is
                // marked as unused.
                reinsertExpression.setUnused(true);
                GroupExpression existGroupExpression = groupExpressions.get(reinsertExpression);
                if (!needMerge(reinsertExpression.getGroup(), existGroupExpression.getGroup())) {
                    // reinsertExpression and existGroupExpression are in the same group, use
                    // existGroupExpression to replace the bestExpression in the group
                    reinsertExpression.getGroup().replaceBestExpression(reinsertExpression, existGroupExpression);
                    // existGroupExpression merges the state of reinsertExpression
                    existGroupExpression.mergeGroupExpression(reinsertExpression);
                } else {
                    // reinsertExpression and existGroupExpression are not in the same group, need to merge them.
                    reinsertExpression.getGroup().deleteBestExpression(reinsertExpression);
                    needMergeGroup.put(reinsertExpression.getGroup(), existGroupExpression.getGroup());
                }
            }
        }
        dstGroup.mergeGroup(srcGroup);

        // Merge the groups that turned out to hold equal expressions (key group into value group)
        needMergeGroup.forEach(this::mergeGroupImpl);
    }

    // Collect every group in the list that reports isEmpty(), plus, for each non-empty group,
    // the first empty group found among the inputs of its first logical expression.
    private Set<Group> getAllEmptyGroups() {
        Set<Group> emptyGroups = Sets.newHashSet();
        for (Group group : getGroups()) {
            if (group.isEmpty()) {
                emptyGroups.add(group);
                continue;
            }
            for (Group childGroup : group.getFirstLogicalExpression().getInputs()) {
                if (childGroup.isEmpty()) {
                    emptyGroups.add(childGroup);
                    break;
                }
            }
        }
        return emptyGroups;
    }

    /**
     * Remove empty groups repeatedly until {@code getAllEmptyGroups()} finds none, since removing
     * one group also removes the expressions that use it as an input.
     */
    public void removeAllEmptyGroup() {
        Set<Group> emptyGroups = getAllEmptyGroups();
        while (!emptyGroups.isEmpty()) {
            for (Group emptyGroup : emptyGroups) {
                removeOneGroup(emptyGroup);
            }
            emptyGroups = getAllEmptyGroups();
        }
    }

    // Remove the group from the group list and drop from the registry every expression that
    // belongs to the group or uses it as an input.
    private void removeOneGroup(Group group) {
        groups.remove(group);

        for (Iterator<Map.Entry<GroupExpression, GroupExpression>>
                iterator = groupExpressions.entrySet().iterator(); iterator.hasNext(); ) {
            GroupExpression groupExpr = iterator.next().getKey();
            if (groupExpr.getGroup() == group) {
                iterator.remove();
                continue;
            }
            for (int i = 0; i < groupExpr.getInputs().size(); i++) {
                if (groupExpr.getInputs().get(i) == group) {
                    // the expression lives in another group: detach it there as well
                    groupExpr.getGroup().removeGroupExpression(groupExpr);
                    iterator.remove();
                    break;
                }
            }
        }
    }

    // Depth-first walk over the inputs of each group's first logical expression, appending the
    // id of every group encountered below root to touch (root itself is not appended here).
    private void deepSearchGroup(Group root, LinkedList<Integer> touch) {
        for (Group group : root.getFirstLogicalExpression().getInputs()) {
            touch.add(group.getId());
            deepSearchGroup(group, touch);
        }
    }

    /*
     * @Note: The function only work in logical rewrite phase !!!
     *
     * When performing replaceRewriteExpression, some groups may become unreachable from rootGroup.
     * These groups should be removed.
     * In order to reduce the number of groups entering Memo,
     * we will delete inaccessible groups in this function.
     */
    public void removeUnreachableGroup() {
        LinkedList<Integer> touch = new LinkedList<>();
        touch.add(rootGroup.getId());
        deepSearchGroup(rootGroup, touch);
        // Look ids up in a hash set: LinkedList.contains would scan the whole list for every group
        Set<Integer> reachableIds = Sets.newHashSet(touch);
        groups.removeIf(g -> !reachableIds.contains(g.getId()));
        groupExpressions.clear();

        // only used in logical rewrite phase, there must be only one logical expression in a group,
        // and removing groupExpressions one by one is too slow, so rebuild directly
        for (Group group : groups) {
            group.getLogicalExpressions().forEach(l -> groupExpressions.put(l, l));
        }
    }

    // For rewrite rule, we directly replace the old group expression by new expression
    public void replaceRewriteExpression(Group targetGroup, OptExpression expression) {
        removeGroupInitLogicExpression(targetGroup);
        GroupExpression groupExpression = copyIn(targetGroup, expression).second;

        // For a group that has been rewritten, we need to derive the property from the expression again
        groupExpression.deriveLogicalPropertyItself();
    }

    // Unregister the group's first logical expression and clear the group's logical expressions.
    // Throws IllegalStateException if the group does not report a valid init state.
    private void removeGroupInitLogicExpression(Group group) {
        GroupExpression initGroupExpression = group.getFirstLogicalExpression();
        groupExpressions.remove(initGroupExpression);

        Preconditions.checkState(group.isValidInitState());

        group.getLogicalExpressions().clear();
    }

    /**
     * Derive logical properties recursively, starting from the first logical expression of the root group.
     */
    public void deriveAllGroupLogicalProperty() {
        getRootGroup().getFirstLogicalExpression().deriveLogicalPropertyRecursively();
    }
}
